import pickle

import ipdb
# Shorthand alias for the ipdb interactive breakpoint: calling ``st()`` drops into the debugger.
st = ipdb.set_trace

import re


def utterance2program_bdetr(utterance, cliport=True):
    """Translate a templated manipulation instruction into a symbolic program.

    The program is a list of op dicts.  Entry 0 is always
    ``{'op': 'detect_objects', 'inputs': []}``; every later op's ``inputs``
    are indices of earlier entries in the same list.

    Templates are tried in this order:

    * ``align ...``: filter + filter_frame + align between the two
      directions named in "... from <d0> to <d1>".
    * ``put ...`` with ``cliport`` true: "put the X in|on the Y" ->
      two filters + "inside" binaryEBM.
    * anything containing both " and " and " to the ": split on " and ",
      parse every part with ``cliport=False`` and combine them with
      ``merge_programs``.
    * ``put ...`` with ``cliport`` false: left/right/above/below relation,
      or "inside" otherwise, optionally preceded by a height comparison for
      "... that are <rel> than the <obj>" subjects.
    * ``pack`` / ``push`` / ``move``: two filters + "inside" binaryEBM.
    * ``rearrange X into a <shape>``: filter + multiAryEBM for the shape.

    Args:
        utterance: instruction text.
        cliport: use the CLIPort-style "put X in/on Y" parse for ``put``
            utterances instead of the relational one.

    Returns:
        The program as a list of op dicts.

    Raises:
        AssertionError: if the utterance matches no template, or a relational
            ``put`` names a relation other than left/right/above/below.
    """
    program = [{'op': 'detect_objects', 'inputs': []}]

    if utterance.startswith('align'):
        program += [
            {"op": "filter", "concept": [utterance.split()[2], 'none'], "inputs": [0]},
            {"op": "filter_frame", "inputs": [0]},
        ]
        # Whole-word splits: a bare 'to' substring also occurs inside words
        # such as "top" or "bottom" and would mangle the directions.
        after_from = re.split(r'\bfrom\b', utterance)[1]
        directions = re.split(r'\bto\b', after_from)
        program.append(
            {'op': 'align', 'concept': [directions[0].strip(), directions[1].strip()], 'inputs': [1, 2]}
        )

    elif utterance.startswith('put') and cliport:
        # "put the <pick> in|on the <place>"
        objects = re.split(' in | on ', utterance)
        pick = ' '.join(objects[0].split()[2:])   # drop "put the"
        place = ' '.join(objects[1].split()[1:])  # drop the leading article
        program += [
            {"op": "filter", "concept": [pick, 'none'], "inputs": [0]},
            {"op": "filter", "concept": [place.strip(), 'none'], "inputs": [0]},
            {'op': 'binaryEBM', "concept": ["inside", "false"], 'inputs': [1, 2]}
        ]

    elif ' and ' in utterance and ' to the ' in utterance:
        # Compound instruction: parse each clause on its own, then merge.
        clauses = [clause.strip() for clause in utterance.split(' and ')]
        return merge_programs(
            [utterance2program_bdetr(clause, cliport=False) for clause in clauses]
        )

    elif utterance.startswith('put') and not cliport:
        # ' to ' keeps its trailing space: a bare ' to' would also match the
        # start of words such as "toy" or "tower".
        subutt = re.split(' inside | on | above | below | to ', utterance)
        subject = ' '.join(subutt[0].split()[2:])  # drop "put the"
        if 'that are' in subject:
            # "<objs> that are <rel> than the <ref>" -> height comparison
            relate_utt = subject.split('than')
            relation = relate_utt[0].split()[-1]
            ref_obj = ' '.join(relate_utt[1].split()[1:])
            program += [
                {"op": "filter", "concept": [ref_obj.strip(), 'none'], "inputs": [0]},
                {"op": 'relate_compare', "concept": [relation.strip(), "height"], "inputs": [0, 1]}
            ]
        else:
            program += [
                {"op": "filter", "concept": [subject.strip(), 'none'], "inputs": [0]},
            ]
        # Index of the op that yields the objects being placed.
        depth = len(program) - 1

        # maxsplit=1 keeps the whole remainder even if the subject text
        # appears again later (e.g. "... the block ... the green block").
        rest = utterance.split(subject.strip(), 1)[1].strip()
        if rest.startswith(('on', 'above', 'below', 'to')) and not re.search(r'\band\b', rest):
            # Whole-word 'the' so words like "leather" or "other" stay intact.
            *rel_words, rel_obj = re.split(r'\bthe\b', rest)
            rel = ' '.join(rel_words)
            relation = next(
                (name for name in ('left', 'right', 'above', 'below') if name in rel),
                None,
            )
            if relation is None:
                raise AssertionError(rel)
        else:
            rel_obj = re.split(r'\bthe\b', rest)[-1]
            relation = 'inside'
        program += [
            {"op": "filter", "concept": [rel_obj.strip(), 'none'], "inputs": [0]},
            {"op": "binaryEBM", "concept": [relation, "false"], "inputs": [depth, depth + 1]},
        ]

    elif utterance.startswith('pack'):
        objects = re.split(' inside | into | in ', utterance)
        # "pack all ..." only drops "pack"; otherwise drop "pack the".
        skip = 1 if objects[0].startswith('pack all') else 2
        obj0 = ' '.join(objects[0].split()[skip:])
        obj1 = ' '.join(objects[1].split()[1:])
        program += [
            {"op": "filter", "concept": [obj0.strip(), 'none'], "inputs": [0]},
            {"op": "filter", "concept": [obj1.strip(), 'none'], "inputs": [0]},
            {'op': 'binaryEBM', "concept": ["inside", "false"], 'inputs': [1, 2]}
        ]

    elif utterance.startswith('push'):
        objects = re.split(r'\binto\b', utterance)
        obj0 = ' '.join(objects[0].split()[2:])
        obj1 = ' '.join(objects[1].split()[1:])
        program += [
            {"op": "filter", "concept": [obj0.strip(), 'none'], "inputs": [0]},
            {"op": "filter", "concept": [obj1.strip(), 'none'], "inputs": [0]},
            {'op': 'binaryEBM', "concept": ["inside", "false"], 'inputs': [1, 2]}
        ]

    elif utterance.startswith('move'):
        # Whole-word 'to' so object names like "tomato" are not split.
        objects = re.split(r'\bto\b', utterance)
        obj0 = ' '.join(objects[0].split()[2:])
        obj1 = ' '.join(objects[1].split()[1:])
        program += [
            {"op": "filter", "concept": [obj0.strip(), 'none'], "inputs": [0]},
            {"op": "filter", "concept": [obj1.strip(), 'none'], "inputs": [0]},
            {'op': 'binaryEBM', "concept": ["inside", "false"], 'inputs': [1, 2]}
        ]

    elif utterance.startswith('rearrange'):
        subutt = re.split(' into | circle | line | tower | square | triangle | rearrange', utterance)
        obj = subutt[0].split("rearrange")[-1].strip()
        # Drop a leading article ("a circle" -> "circle").  Splitting on the
        # letter 'a' used to cut inside names ("a square" -> "re").
        shape = re.sub(r'^(?:an?|the)\s+', '', subutt[1].strip())
        program += [
            {"op": "filter", "concept": [obj, 'none'], "inputs": [0]},
            {"op": "multiAryEBM", "concept": [shape, None, None, "false"], "inputs": [1]},
        ]

    else:
        # Explicit raise: a bare `assert` disappears under `python -O`.
        raise AssertionError(utterance)
    return program


def merge_programs(programs):
    """Combine several single-instruction programs into one program.

    Each input program refers to earlier ops by list index in ``inputs``
    and is expected to have its ``detect_objects`` op at index 0.  The
    merged program is laid out as::

        [programs[0][0],                      # the detect_objects op
         *unique filter ops (first-seen order, deduplicated by equality),
         *all other ops (in input order)]

    The ``inputs`` of every non-filter op are rewritten to the merged
    positions; references to a ``detect_objects`` op map to index 0.
    Filter ops are copied unchanged, so their own ``inputs`` are assumed to
    point at index 0 (true for every filter built in this module).

    Args:
        programs: list of programs (lists of op dicts).  Not modified.

    Returns:
        The merged program.  Non-filter ops are new dicts; the detect and
        filter ops are the same objects as in the inputs.

    Raises:
        IndexError: if ``programs`` or its first program is empty.
        ValueError: if an op refers to an index that is not an earlier op of
            its own program.
    """
    merged_program = [programs[0][0]]

    # Pass 1: collect the unique filters first, so every filter's final slot
    # is known before any action inputs are rewritten.
    filters = []
    for program in programs:
        for op in program:
            if op['op'] == 'filter' and op not in filters:
                filters.append(op)

    # Pass 2: rewrite non-filter ops through a per-program old -> new index map.
    actions = []
    first_action_slot = 1 + len(filters)
    for program in programs:
        new_index = {}
        for old_index, op in enumerate(program):
            if op['op'] == 'detect_objects':
                new_index[old_index] = 0
            elif op['op'] == 'filter':
                new_index[old_index] = filters.index(op) + 1
            else:
                try:
                    inputs = [new_index[i] for i in op['inputs']]
                except KeyError as err:
                    raise ValueError(
                        f"op {op!r} refers to input {err.args[0]!r}, "
                        "which is not an earlier op of its program"
                    ) from err
                new_index[old_index] = first_action_slot + len(actions)
                # Copy rather than mutate the caller's op dict.
                actions.append({**op, 'inputs': inputs})

    merged_program.extend(filters)
    merged_program.extend(actions)
    return merged_program

if __name__ == "__main__":
    # Smoke check: parse one sample instruction, then print the resulting
    # program followed by the sentence it came from.
    demo_sentence = "rearrange red cubes into a circle"
    demo_program = utterance2program_bdetr(demo_sentence, cliport=False)
    for shown in (demo_program, demo_sentence):
        print(shown)
    # Batch variant over a pickled list of utterances (kept for reference):
    # sents = pickle.load(open('/projects/""/ns_transporter_data/transporter_data_sep_100d/composition-seen-colors-train/ebm/lang_relations_100.pickle', 'rb'))
    # for utterance in sents:
    #     if utterance.startswith('done'):
    #         continue
    #     program = utterance2program_bdetr(utterance, cliport=False)
    #     print(utterance)
    #     print(program)